[YOLOv8] 生産ラインを流れるアヒルを追跡して数をかぞえてみました
1 はじめに
製造ビジネステクノロジー部の平内(SIN)です。
Ultralytics 社の YOLOv8は、最先端、高速、正確で非常に使いやすく設計された物体検出モデルです。
YOLOv8は、さまざまなオブジェクトの検出、インスタンスのセグメンテーション、画像分類、ポーズ推定などを処理することが可能ですが、トラックング(追跡)タスクについても対応しています。
今回は、このオブジェクト検出及び、トラッキングを使用して、生産ラインを流れるアヒルをカウントしてみました。
最初に、作成したデモをご確認下さい。
2 物体検出(ファインチューニング)
最初に、アヒルを検出するためモデル作成します。
YOLOv8のファインチューニングは、非常に簡単で、形式通りのデータを準備してmodel.train()を実行するだけです。
from ultralytics import YOLO
model = YOLO("yolov8l.pt")
model.train(data="dataset.yaml", epochs=15, batch=8, workers=4, degrees=90.0)
データセットの配置は、dataset.yamlで指定します。
dataset.yaml
train: /home/dataset/yolo/train/images
val: /home/dataset/yolo/valid/images
nc: 1
names: ["ahiru"]
実際のデータの配置は、以下のとおりです。
.
├── train
│ ├── images
│ │ ├── 00600.png
│ │ ├── 00601.png
・・・略・・・
│ │ ├── 02998.png
│ │ └── 02999.png
│ ├── labels
│ │ ├── 00600.txt
│ │ ├── 00601.txt
・・・略・・・
│ │ ├── 02998.txt
│ │ └── 02999.txt
│ └── labels.cache
└── valid
├── images
│ ├── 00000.png
│ ├── 00001.png
・・・略・・・
│ ├── 00598.png
│ └── 00599.png
├── labels
│ ├── 00000.txt
│ ├── 00001.txt
・・・略・・・
│ ├── 00598.txt
│ └── 00599.txt
└── labels.cache
6 directories, 6002 files
データセットは、対象オブジェクトを撮影し、Segment Anything Modelで切り出して、プログラムによる合成により大量生産しています。(今回は、3,000枚の画像と約24,000個のアノテーション)
カメラによる撮影
データセット生成
参考:
トレーニングの状況です。
__
epoch, train/box_loss, train/cls_loss, metrics/mAP50(B), metrics/mAP50-95(B),r/pg2
1, 1.0073, 0.62287, 0.92468, 0.42867,66444
2, 0.88396, 0.49335, 0.87443, 0.57074,12433
3, 0.82376, 0.46769, 0.99497, 0.70049,17341
4, 0.75457, 0.42457, 0.99496, 0.74066,01604
5, 0.68971, 0.39211, 0.98101, 0.71505,01472
6, 0.56326, 0.28422, 0.93799, 0.72821,00134
7, 0.53406, 0.26958, 0.995, 0.80915,01208
8, 0.49782, 0.25346, 0.99287, 0.89071,01076
9, 0.45277, 0.23768, 0.995, 0.86261,00944
10, 0.41017, 0.21778, 0.995, 0.91045,00812
11, 0.3922, 0.21385, 0.995, 0.94417,00068
12, 0.35645, 0.19367, 0.995, 0.91055,00548
13, 0.33934, 0.18195, 0.995, 0.92263,00416
14, 0.31723, 0.17222, 0.995, 0.96464,00284
15, 0.2961, 0.16354, 0.995, 0.97386,00152
3 トラッキング
作成したモデルでトラッキングしているコードです。
検出したオブジェクトのIDを確認し、画面の右に位置する時にリストアップしておき、中央を超えた時点で、カウンターをアップしています。
import cv2
from ultralytics import YOLO
COLORS = [
(255, 80, 0),
(255, 255, 0),
(255, 80, 100),
(255, 80, 255),
(255, 120, 255),
(155, 255, 255),
(155, 155, 255),
(155, 200, 200),
(155, 80, 155),
(200, 200, 200),
]
model = YOLO("./runs/detect/train4/weights/best.pt")
class Counter:
before_counting_id_list = [] # 画面の右側にあるカウントする前のIDリスト
def __init__(self, w, h):
self.counter = 0
self.w = w
self.h = h
def set(self, id, box):
# 対象オブジェクトのX座標の中心を取得
x1 = box[0]
x2 = box[2]
center = int(x1 + (x2 - x1) / 2)
# 画面の中央より右側にある場合
if center > int(self.w / 2):
# まだ、リストの損じしない場合、リストに追加
if not id in self.before_counting_id_list:
self.before_counting_id_list.append(id)
# 画面の中央より左側にある場合
if center < int(self.w / 2):
# IDにあれば、リストを削除してカウントアップする
if id in self.before_counting_id_list:
self.counter += 1
self.before_counting_id_list.remove(id)
def disp_counter(self, frame):
# 中央のラインを描画
cv2.line(
frame, (int(self.w / 2), 70), (int(self.w / 2), self.h - 20), (0, 0, 255), 2
)
# カウンターを描画
cv2.putText(
frame,
"COUNTER: {}".format(self.counter),
(230, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 0, 255),
3,
)
# 検出したオブジェクトにラベルを表示する関数
def disp_label(frame, box, id):
color = COLORS[id % 10]
x1 = int(box[0])
y1 = int(box[1])
x2 = int(box[2])
y2 = int(box[3])
cv2.rectangle(
frame,
(x1, y1),
(x2, y2),
color,
2,
)
cv2.putText(
frame,
"ID:{}".format(id),
(x1, y1 - 15),
cv2.FONT_HERSHEY_SIMPLEX,
1,
color,
3,
)
def main() -> int:
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FPS, 5) # 強制的にFLSを下げる
w, h, fps = (
int(cap.get(x))
for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)
)
print("w:{} h:{} fps:{}".format(w, h, fps))
counter = Counter(w, h)
while True:
ret, frame = cap.read()
if not ret:
print("Video ERROR")
break
# 検出したオブジェクトを取得
results = model.track(frame, persist=True, conf=0.3)
# 検出した個々のオブジェクトを処理する
for box in results[0].boxes:
r = box.xyxy.tolist()
# トラッキングIDを取得
id = int(box.id) if box.id is not None else 0
# カウント処理
counter.set(id, r[0])
disp_label(frame, r[0], id)
# カウント表示
counter.disp_counter(frame)
frame = cv2.resize(frame, dsize=None, fx=1.5, fy=1.5)
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == "__main__":
main()
4 最後に
YOLOv8で物体検出すると、検出されたオブジェクトの固有IDが取得できるため、これを使用することで、各種のトラッキングタスクが処理可能になります。
ただし、追跡のためには、対象オブジェクトを駒落ちせずに確実に検出する必要があるため、物体検出モデルの精度は、比較的高いものが要求されます。
SAMを使用したデータセット作成では、非常に質の高い大量のデータが簡単に生成できるため、これを可能にしているとも言えそうです。
5 参考リンク